from __future__ import print_function
import torch
import torch.nn.functional as F
import argparse

from utils import load_data
from utils import random_random_crop
from utils import abs_coord_to_norm
from utils import norm_coord_to_abs
from utils import save_model
from utils import scheduler_step
from utils import random_o_crop

from network_and_loss import CoordNet
from network_and_loss import PWConLoss


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float


def _str2bool(value):
    """Convert a command-line token such as "true"/"false", "yes"/"no", "1"/"0" to a bool.

    Without an explicit ``type``, argparse keeps whatever the user typed as a
    string, so ``--random_flip False`` would yield the truthy string "False".

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='CoordNet training with PWConLoss')

# Hyperparameters
parser.add_argument('--task', type=str, default="COFW", help='Task dataset')
parser.add_argument('--mode', type=str, default="all", help='Mode argument passed to PWConLoss')
# Augmentation switches are parsed with _str2bool so that e.g. "--random_flip False" really disables them.
parser.add_argument('--random_scale', type=_str2bool, default=True, help='Whether to apply random scale')
parser.add_argument('--random_flip', type=_str2bool, default=True, help='Whether to apply random flip')
parser.add_argument('--random_rotation', type=_str2bool, default=False, help='Whether to apply random rotation')

parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
parser.add_argument('--num_epochs', type=int, default=2000, help='Maximum number of epochs')
# main() builds two Adam parameter groups: conv + proj layers, and the regression head.
parser.add_argument('--learning_rate_proj', type=float, default=1E-2, help='Learning rate for the conv and projection layers')
parser.add_argument('--learning_rate_head', type=float, default=1E-3, help='Learning rate for the regression head')
parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='Learning rate decay rate')
parser.add_argument('--lr_decay_interval', type=int, default=1000, help='Learning rate decay interval')

args = parser.parse_args()

def main(save_interval=500):
    """Train CoordNet with PWConLoss + MSE, evaluating on the test split after every epoch.

    All hyperparameters come from the module-level ``args``. A checkpoint is
    written through ``save_model`` every ``save_interval`` epochs and always
    after the final epoch, so the last trained model is never lost.

    Args:
        save_interval: number of epochs between checkpoints (default 500).
    """
    torch.cuda.empty_cache()
    train_loader, test_loader = load_data(args.task, args.batch_size, args.random_scale,
                                          args.random_flip, args.random_rotation)

    coordnet = CoordNet().to(device)

    # Two parameter groups with independent learning rates:
    # the encoder (conv + proj) and the regression head.
    optimizer = torch.optim.Adam([
        {'params': list(coordnet.conv.parameters()) + list(coordnet.proj.parameters()),
         'lr': args.learning_rate_proj},
        {'params': coordnet.head.parameters(), 'lr': args.learning_rate_head},
    ])

    criterion = PWConLoss(args.mode).to(device)

    for epoch in range(args.num_epochs):
        # The schedule is advanced before every epoch except the first.
        if epoch != 0:
            scheduler_step(optimizer, epoch, args.lr_decay_interval, args.lr_decay_rate)

        pwconloss, mseloss = train(coordnet, train_loader, criterion, optimizer)

        regression_error_mean, regression_error_std = test(coordnet, test_loader)

        print("\nEpoch: {}/{}..".format(epoch+1, args.num_epochs).ljust(14),
              "PWConloss: {:.3f}.. ".format(pwconloss).ljust(12),
              "MSEloss: {:.3f}.. ".format(mseloss).ljust(12))
        print("Regression_error_mean: {:.3f}.. ".format(regression_error_mean).ljust(12),
              "Regression_error_std: {:.3f} .. ".format(regression_error_std).ljust(12))

        # Save the model .pth file periodically and once more after the final epoch
        # (previously skipped when num_epochs was not a multiple of the interval).
        completed_epochs = epoch + 1
        if completed_epochs % save_interval == 0 or completed_epochs == args.num_epochs:
            save_model("CoordNet", coordnet, optimizer, completed_epochs)
        
        
        
def train(coordnet, train_loader, criterion, optimizer, mse_weight=1000):
    """Run one training epoch and return the batch-averaged losses.

    For every batch, ``random_random_crop`` returns (per the unpacked names) two
    random views of the images, their crop coordinates, and distance/relationship
    tensors for the view1-view1, view1-view2 and view2-view2 pairs. Both views
    go through ``coordnet`` in one concatenated forward pass. The first output
    ``z`` feeds ``criterion`` (PWConLoss). The second output ``c`` is regressed
    with MSE against the crop coordinates after ``abs_coord_to_norm``.

    Args:
        coordnet: model whose forward returns ``(z, c)``.
        train_loader: iterable of ``(images, landmark_coords)`` batches; the
            landmark coordinates are not used by this objective.
        criterion: the PWConLoss instance.
        optimizer: optimizer stepped once per batch.
        mse_weight: multiplier for the MSE term in the total loss
            (default 1000, the previously hard-coded value).

    Returns:
        (mean PWConLoss, mean MSE loss) over the epoch's batches, as Python floats.
    """
    coordnet.train()
    num_batches = len(train_loader)
    total_pwconloss = 0
    total_mseloss = 0

    for images, _landmark_coords in train_loader:
        # Only the images are needed on the device; the landmark labels are unused,
        # so copying them to the GPU every batch was wasted transfer.
        images = images.to(device)
        B = images.size(0)

        random_view1, random_coords1, random_view2, random_coords2, r1_r1_distance, r1_r1_relationship, \
            r1_r2_distance, r1_r2_relationship, r2_r2_distance, r2_r2_relationship = random_random_crop(images)
        # One forward pass for both views; the split below expects z to hold exactly 2*B rows,
        # view 1 first, then view 2.
        random1_random2 = torch.cat((random_view1, random_view2), 0)

        # Module.zero_grad() clears every parameter of coordnet, a superset of the optimizer's
        # groups (conv, proj, head), so a separate optimizer.zero_grad() is redundant.
        coordnet.zero_grad()
        z, c = coordnet(random1_random2)

        # Pair each image's two projections along a new dim 1: shape (B, 2, ...).
        z1, z2 = torch.split(z, [B, B], dim=0)
        projections = torch.stack((z1, z2), dim=1)
        pwconloss = criterion(projections, r1_r1_distance, r1_r1_relationship,
                              r1_r2_distance, r1_r2_relationship, r2_r2_distance, r2_r2_relationship)

        # Regression target: crop coordinates of both views, converted by abs_coord_to_norm
        # (presumably to the normalized range the head predicts), in the same row order as c.
        random_coords1_rel = abs_coord_to_norm(random_coords1)
        random_coords2_rel = abs_coord_to_norm(random_coords2)
        label_coords_rel = torch.cat((random_coords1_rel, random_coords2_rel), 0)
        mseloss = F.mse_loss(c, label_coords_rel)

        total_pwconloss += pwconloss.item() / num_batches
        total_mseloss += mseloss.item() / num_batches

        loss = pwconloss + mse_weight * mseloss
        loss.backward()
        optimizer.step()

    return total_pwconloss, total_mseloss



def test(coordnet, test_loader):
    """Measure coordinate-regression error of ``coordnet`` on random test crops.

    For each batch, ``random_o_crop`` returns crops ``o`` and their ground-truth
    crop coordinates. The model's coordinate output ``c`` is mapped back with
    ``norm_coord_to_abs`` and compared with the ground truth. Each crop's error
    is the Euclidean norm of its two-component absolute difference (presumably
    (y, x), as the original variable name suggested).

    Args:
        coordnet: model whose forward returns ``(z, c)``.
        test_loader: iterable of ``(images, _)`` batches; labels are ignored.

    Returns:
        (mean, std) of the per-crop error as 0-dim tensors. Both are NaN when the
        loader yields no batches, and the std is NaN when there is a single crop.
    """
    coordnet.eval()
    with torch.no_grad():
        chunks = []
        for images, _ in test_loader:
            images = images.to(device)

            # Use the real batch size: the last batch can be smaller than args.batch_size
            # when the dataset size is not a multiple of it.
            # NOTE(review): assumes random_o_crop's second argument is the number of
            # images in the batch — confirm in utils.
            o, crop_coords = random_o_crop(images, images.size(0))
            _, c = coordnet(o)

            chunks.append(torch.abs(norm_coord_to_abs(c) - crop_coords))

        if chunks:
            # Concatenate once rather than growing a tensor every batch (quadratic copying).
            error = torch.cat(chunks, dim=0)
            # Match the promotion the old code got by concatenating onto a default-dtype zero row.
            error = error.to(torch.promote_types(error.dtype, torch.get_default_dtype()))
        else:
            error = torch.zeros(0, 2, device=device)

        # hypot computes sqrt(a**2 + b**2) without intermediate overflow/underflow.
        error = torch.hypot(error[:, 0], error[:, 1])
        error_mean = error.mean()
        error_std = error.std()

    return error_mean, error_std


if __name__ == '__main__':
    # Start training only when this file is executed directly as a script.
    main()
    
    
    
    
